{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Rulefit demo - Titanic Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## H2O Rulefit algorithm\n",
    "\n",
    "Rulefit algorithm combines tree ensembles and linear models to take advantage of both methods: a tree ensemble accuracy and a linear model interpretability. The general algorithm fits a tree ensebmle to the data, builds a rule ensemble by traversing each tree, evaluates the rules on the data to build a rule feature set and fits a sparse linear model (LASSO) to the rule feature set joined with the original feature set.\n",
    "\n",
    "For more information, refer to: http://statweb.stanford.edu/~jhf/ftp/RuleFit.pdf by Jerome H. Friedman and Bogden E. Popescu.\n",
    "\n",
    "## Demo example\n",
    "\n",
    "We will train a rulefit model to predict the rules defining whether or not someone will survive:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking whether there is an H2O instance running at http://localhost:54321 . connected.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div style=\"overflow:auto\"><table style=\"width:50%\"><tr><td>H2O_cluster_uptime:</td>\n",
       "<td>4 mins 19 secs</td></tr>\n",
       "<tr><td>H2O_cluster_timezone:</td>\n",
       "<td>Europe/Prague</td></tr>\n",
       "<tr><td>H2O_data_parsing_timezone:</td>\n",
       "<td>UTC</td></tr>\n",
       "<tr><td>H2O_cluster_version:</td>\n",
       "<td>3.34.0.99999</td></tr>\n",
       "<tr><td>H2O_cluster_version_age:</td>\n",
       "<td>17 minutes </td></tr>\n",
       "<tr><td>H2O_cluster_name:</td>\n",
       "<td>zuzanaolajcova</td></tr>\n",
       "<tr><td>H2O_cluster_total_nodes:</td>\n",
       "<td>1</td></tr>\n",
       "<tr><td>H2O_cluster_free_memory:</td>\n",
       "<td>3.546 Gb</td></tr>\n",
       "<tr><td>H2O_cluster_total_cores:</td>\n",
       "<td>12</td></tr>\n",
       "<tr><td>H2O_cluster_allowed_cores:</td>\n",
       "<td>12</td></tr>\n",
       "<tr><td>H2O_cluster_status:</td>\n",
       "<td>locked, healthy</td></tr>\n",
       "<tr><td>H2O_connection_url:</td>\n",
       "<td>http://localhost:54321</td></tr>\n",
       "<tr><td>H2O_connection_proxy:</td>\n",
       "<td>{\"http\": null, \"https\": null}</td></tr>\n",
       "<tr><td>H2O_internal_security:</td>\n",
       "<td>False</td></tr>\n",
       "<tr><td>H2O_API_Extensions:</td>\n",
       "<td>Algos, AutoML, Core V3, TargetEncoder, Core V4</td></tr>\n",
       "<tr><td>Python_version:</td>\n",
       "<td>3.8.1 final</td></tr></table></div>"
      ],
      "text/plain": [
       "--------------------------  ----------------------------------------------\n",
       "H2O_cluster_uptime:         4 mins 19 secs\n",
       "H2O_cluster_timezone:       Europe/Prague\n",
       "H2O_data_parsing_timezone:  UTC\n",
       "H2O_cluster_version:        3.34.0.99999\n",
       "H2O_cluster_version_age:    17 minutes\n",
       "H2O_cluster_name:           zuzanaolajcova\n",
       "H2O_cluster_total_nodes:    1\n",
       "H2O_cluster_free_memory:    3.546 Gb\n",
       "H2O_cluster_total_cores:    12\n",
       "H2O_cluster_allowed_cores:  12\n",
       "H2O_cluster_status:         locked, healthy\n",
       "H2O_connection_url:         http://localhost:54321\n",
       "H2O_connection_proxy:       {\"http\": null, \"https\": null}\n",
       "H2O_internal_security:      False\n",
       "H2O_API_Extensions:         Algos, AutoML, Core V3, TargetEncoder, Core V4\n",
       "Python_version:             3.8.1 final\n",
       "--------------------------  ----------------------------------------------"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import h2o\n",
    "from h2o.estimators import H2ORuleFitEstimator, H2ORandomForestEstimator\n",
    "\n",
    "# init h2o cluster\n",
    "h2o.init()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parse progress: |████████████████████████████████████████████████████████████████| (done) 100%\n"
     ]
    }
   ],
   "source": [
    "df = h2o.import_file(\"https://s3.amazonaws.com/h2o-public-test-data/smalldata/gbm_test/titanic.csv\",\n",
    "                       col_types={'pclass': \"enum\", 'survived': \"enum\"})\n",
    "x =  [\"age\", \"sibsp\", \"parch\", \"sex\", \"pclass\"]\n",
    "\n",
    "# Split the dataset into train and test\n",
    "train, test = df.split_frame(ratios=[.8], seed=1234)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the `algorithm` parameter, a user can set whether algorithm will use DRF or GBM to fit a tree enseble. \n",
    "\n",
    "Using the `min_rule_length` and `max_rule_length` parameters, a user can set interval of tree enseble depths to be fitted. The bigger this interval is, the more tree ensembles will be fitted (1 per each depth) and the bigger the rule feature set will be.\n",
    "\n",
    "Using the `max_num_rules` parameter, the maximum number of rules to return can be set.\n",
    "\n",
    "Using the `model_type` parameter, the type of base learners in the enseble can be set.\n",
    "\n",
    "Using the `rule_generation_ntrees` parameter, the number of trees for tree enseble can be set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rulefit Model Build progress: |██████████████████████████████████████████████████| (done) 100%\n",
      "Model Details\n",
      "=============\n",
      "H2ORuleFitEstimator :  RuleFit\n",
      "Model Key:  RuleFit_model_python_1636562504000_1\n",
      "\n",
      "\n",
      "Rulefit Model Summary: \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>family</th>\n",
       "      <th>link</th>\n",
       "      <th>regularization</th>\n",
       "      <th>number_of_predictors_total</th>\n",
       "      <th>number_of_active_predictors</th>\n",
       "      <th>number_of_iterations</th>\n",
       "      <th>rule_ensemble_size</th>\n",
       "      <th>number_of_trees</th>\n",
       "      <th>number_of_internal_trees</th>\n",
       "      <th>min_depth</th>\n",
       "      <th>max_depth</th>\n",
       "      <th>mean_depth</th>\n",
       "      <th>min_leaves</th>\n",
       "      <th>max_leaves</th>\n",
       "      <th>mean_leaves</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td></td>\n",
       "      <td>binomial</td>\n",
       "      <td>logit</td>\n",
       "      <td>Lasso (lambda = 0.01292 )</td>\n",
       "      <td>20784</td>\n",
       "      <td>8</td>\n",
       "      <td>3</td>\n",
       "      <td>20776.0</td>\n",
       "      <td>500.0</td>\n",
       "      <td>500.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>10.0</td>\n",
       "      <td>5.5</td>\n",
       "      <td>0.0</td>\n",
       "      <td>135.0</td>\n",
       "      <td>41.552</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       family   link             regularization  number_of_predictors_total  \\\n",
       "0    binomial  logit  Lasso (lambda = 0.01292 )                       20784   \n",
       "\n",
       "  number_of_active_predictors  number_of_iterations  rule_ensemble_size  \\\n",
       "0                           8                     3             20776.0   \n",
       "\n",
       "   number_of_trees  number_of_internal_trees  min_depth  max_depth  \\\n",
       "0            500.0                     500.0        0.0       10.0   \n",
       "\n",
       "   mean_depth  min_leaves  max_leaves  mean_leaves  \n",
       "0         5.5         0.0       135.0       41.552  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "ModelMetricsBinomialGLM: rulefit\n",
      "** Reported on train data. **\n",
      "\n",
      "MSE: 0.14668202166384883\n",
      "RMSE: 0.3829908897922362\n",
      "LogLoss: 0.4616331658988569\n",
      "Null degrees of freedom: 1053\n",
      "Residual degrees of freedom: 1045\n",
      "Null deviance: 1405.0919048764067\n",
      "Residual deviance: 973.1227137147903\n",
      "AIC: 991.1227137147903\n",
      "AUC: 0.8361042692939246\n",
      "AUCPR: 0.7904193564939762\n",
      "Gini: 0.6722085385878491\n",
      "\n",
      "Confusion Matrix (Act/Pred) for max f1 @ threshold = 0.44132286664639514: \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>Error</th>\n",
       "      <th>Rate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>526.0</td>\n",
       "      <td>122.0</td>\n",
       "      <td>0.1883</td>\n",
       "      <td>(122.0/648.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>106.0</td>\n",
       "      <td>300.0</td>\n",
       "      <td>0.2611</td>\n",
       "      <td>(106.0/406.0)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Total</td>\n",
       "      <td>632.0</td>\n",
       "      <td>422.0</td>\n",
       "      <td>0.2163</td>\n",
       "      <td>(228.0/1054.0)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              0      1   Error             Rate\n",
       "0      0  526.0  122.0  0.1883    (122.0/648.0)\n",
       "1      1  106.0  300.0  0.2611    (106.0/406.0)\n",
       "2  Total  632.0  422.0  0.2163   (228.0/1054.0)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Maximum Metrics: Maximum metrics at their respective thresholds\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>threshold</th>\n",
       "      <th>value</th>\n",
       "      <th>idx</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>max f1</td>\n",
       "      <td>0.441323</td>\n",
       "      <td>0.724638</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>max f2</td>\n",
       "      <td>0.160033</td>\n",
       "      <td>0.783832</td>\n",
       "      <td>7.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>max f0point5</td>\n",
       "      <td>0.809013</td>\n",
       "      <td>0.774478</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>max accuracy</td>\n",
       "      <td>0.523805</td>\n",
       "      <td>0.790323</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>max precision</td>\n",
       "      <td>0.809013</td>\n",
       "      <td>0.919048</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>max recall</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>max specificity</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>0.973765</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>max absolute_mcc</td>\n",
       "      <td>0.523805</td>\n",
       "      <td>0.550968</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>max min_per_class_accuracy</td>\n",
       "      <td>0.441323</td>\n",
       "      <td>0.738916</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>max mean_per_class_accuracy</td>\n",
       "      <td>0.441323</td>\n",
       "      <td>0.775322</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>max tns</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>631.000000</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>max fns</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>217.000000</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>max fps</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>648.000000</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>max tps</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>406.000000</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>max tnr</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>0.973765</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>max fnr</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>0.534483</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>max fpr</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>max tpr</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                         metric  threshold       value  idx\n",
       "0                        max f1   0.441323    0.724638  3.0\n",
       "1                        max f2   0.160033    0.783832  7.0\n",
       "2                  max f0point5   0.809013    0.774478  1.0\n",
       "3                  max accuracy   0.523805    0.790323  2.0\n",
       "4                 max precision   0.809013    0.919048  1.0\n",
       "5                    max recall   0.156308    1.000000  8.0\n",
       "6               max specificity   0.855041    0.973765  0.0\n",
       "7              max absolute_mcc   0.523805    0.550968  2.0\n",
       "8    max min_per_class_accuracy   0.441323    0.738916  3.0\n",
       "9   max mean_per_class_accuracy   0.441323    0.775322  3.0\n",
       "10                      max tns   0.855041  631.000000  0.0\n",
       "11                      max fns   0.855041  217.000000  0.0\n",
       "12                      max fps   0.156308  648.000000  8.0\n",
       "13                      max tps   0.156308  406.000000  8.0\n",
       "14                      max tnr   0.855041    0.973765  0.0\n",
       "15                      max fnr   0.855041    0.534483  0.0\n",
       "16                      max fpr   0.156308    1.000000  8.0\n",
       "17                      max tpr   0.156308    1.000000  8.0"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Gains/Lift Table: Avg response rate: 38.52 %, avg score: 38.52 %\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>group</th>\n",
       "      <th>cumulative_data_fraction</th>\n",
       "      <th>lower_threshold</th>\n",
       "      <th>lift</th>\n",
       "      <th>cumulative_lift</th>\n",
       "      <th>response_rate</th>\n",
       "      <th>score</th>\n",
       "      <th>cumulative_response_rate</th>\n",
       "      <th>cumulative_score</th>\n",
       "      <th>capture_rate</th>\n",
       "      <th>cumulative_capture_rate</th>\n",
       "      <th>gain</th>\n",
       "      <th>cumulative_gain</th>\n",
       "      <th>kolmogorov_smirnov</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0.195446</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>2.381821</td>\n",
       "      <td>2.381821</td>\n",
       "      <td>0.917476</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>0.917476</td>\n",
       "      <td>0.855041</td>\n",
       "      <td>0.465517</td>\n",
       "      <td>0.465517</td>\n",
       "      <td>138.182123</td>\n",
       "      <td>138.182123</td>\n",
       "      <td>0.439283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>0.348197</td>\n",
       "      <td>0.523805</td>\n",
       "      <td>1.402839</td>\n",
       "      <td>1.952350</td>\n",
       "      <td>0.540373</td>\n",
       "      <td>0.530891</td>\n",
       "      <td>0.752044</td>\n",
       "      <td>0.712839</td>\n",
       "      <td>0.214286</td>\n",
       "      <td>0.679803</td>\n",
       "      <td>40.283940</td>\n",
       "      <td>95.234963</td>\n",
       "      <td>0.539371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>0.400380</td>\n",
       "      <td>0.414525</td>\n",
       "      <td>1.132826</td>\n",
       "      <td>1.845540</td>\n",
       "      <td>0.436364</td>\n",
       "      <td>0.441323</td>\n",
       "      <td>0.710900</td>\n",
       "      <td>0.677452</td>\n",
       "      <td>0.059113</td>\n",
       "      <td>0.738916</td>\n",
       "      <td>13.282579</td>\n",
       "      <td>84.553965</td>\n",
       "      <td>0.550645</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>0.528463</td>\n",
       "      <td>0.307335</td>\n",
       "      <td>0.788433</td>\n",
       "      <td>1.589329</td>\n",
       "      <td>0.303704</td>\n",
       "      <td>0.307335</td>\n",
       "      <td>0.612208</td>\n",
       "      <td>0.587746</td>\n",
       "      <td>0.100985</td>\n",
       "      <td>0.839901</td>\n",
       "      <td>-21.156723</td>\n",
       "      <td>58.932883</td>\n",
       "      <td>0.506568</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.156308</td>\n",
       "      <td>0.339525</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.130785</td>\n",
       "      <td>0.158208</td>\n",
       "      <td>0.385199</td>\n",
       "      <td>0.385203</td>\n",
       "      <td>0.160099</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>-66.047517</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   group  cumulative_data_fraction  lower_threshold      lift  \\\n",
       "0      1                  0.195446         0.855041  2.381821   \n",
       "1      2                  0.348197         0.523805  1.402839   \n",
       "2      3                  0.400380         0.414525  1.132826   \n",
       "3      4                  0.528463         0.307335  0.788433   \n",
       "4      5                  1.000000         0.156308  0.339525   \n",
       "\n",
       "   cumulative_lift  response_rate     score  cumulative_response_rate  \\\n",
       "0         2.381821       0.917476  0.855041                  0.917476   \n",
       "1         1.952350       0.540373  0.530891                  0.752044   \n",
       "2         1.845540       0.436364  0.441323                  0.710900   \n",
       "3         1.589329       0.303704  0.307335                  0.612208   \n",
       "4         1.000000       0.130785  0.158208                  0.385199   \n",
       "\n",
       "   cumulative_score  capture_rate  cumulative_capture_rate        gain  \\\n",
       "0          0.855041      0.465517                 0.465517  138.182123   \n",
       "1          0.712839      0.214286                 0.679803   40.283940   \n",
       "2          0.677452      0.059113                 0.738916   13.282579   \n",
       "3          0.587746      0.100985                 0.839901  -21.156723   \n",
       "4          0.385203      0.160099                 1.000000  -66.047517   \n",
       "\n",
       "   cumulative_gain  kolmogorov_smirnov  \n",
       "0       138.182123            0.439283  \n",
       "1        95.234963            0.539371  \n",
       "2        84.553965            0.550645  \n",
       "3        58.932883            0.506568  \n",
       "4         0.000000            0.000000  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rfit = H2ORuleFitEstimator(algorithm=\"drf\", \n",
    "                               min_rule_length=1, \n",
    "                               max_rule_length=10, \n",
    "                               max_num_rules=100, \n",
    "                               model_type=\"rules_and_linear\",\n",
    "                               rule_generation_ntrees=50,\n",
    "                               seed=1234)\n",
    "rfit.train(training_frame=train, x=x, y=\"survived\")"
   ]
  },
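  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model summary above can be cross-checked against the parameters with a little arithmetic. The sketch below assumes, as described earlier, that one tree ensemble is fitted per depth between `min_rule_length` and `max_rule_length`, each with `rule_generation_ntrees` trees. The variable names are local to this example and nothing here calls H2O.\n",
    "\n",
    "```python\n",
    "# Parameter values copied from the training cell above.\n",
    "min_rule_length, max_rule_length, rule_generation_ntrees = 1, 10, 50\n",
    "\n",
    "# One tree ensemble per depth in the closed interval [min_rule_length, max_rule_length].\n",
    "depths = list(range(min_rule_length, max_rule_length + 1))\n",
    "n_ensembles = len(depths)\n",
    "\n",
    "# Every ensemble contributes rule_generation_ntrees trees.\n",
    "total_trees = n_ensembles * rule_generation_ntrees\n",
    "mean_depth = sum(depths) / len(depths)\n",
    "\n",
    "print(n_ensembles, total_trees, mean_depth)\n",
    "```\n",
    "\n",
    "This gives 10 ensembles, 500 trees, and a mean depth of 5.5, which agrees with `number_of_trees` and `mean_depth` in the summary table."
   ]
  },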
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The output for the Rulefit model includes:\n",
    "    - model parameters\n",
    "    - rule importences in tabular form\n",
    "    - training and validation metrics of the underlying linear model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Rule Importance: \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>variable</th>\n",
       "      <th>coefficient</th>\n",
       "      <th>rule</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td></td>\n",
       "      <td>M2T21N13</td>\n",
       "      <td>1.298409e+00</td>\n",
       "      <td>(sex in {female}) &amp; (sibsp &lt; 3.5 or sibsp is NA) &amp; (pclass in {1, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td></td>\n",
       "      <td>M2T23N21</td>\n",
       "      <td>-8.453746e-01</td>\n",
       "      <td>(sex in {male} or sex is NA) &amp; (pclass in {2, 3} or pclass is NA) ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td></td>\n",
       "      <td>M1T0N7</td>\n",
       "      <td>3.809983e-01</td>\n",
       "      <td>(pclass in {1, 2}) &amp; (sex in {female})</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td></td>\n",
       "      <td>M1T28N10</td>\n",
       "      <td>-3.448192e-01</td>\n",
       "      <td>(sex in {male} or sex is NA) &amp; (age &gt;= 13.496771812438965 or age i...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td></td>\n",
       "      <td>M1T23N7</td>\n",
       "      <td>3.310857e-01</td>\n",
       "      <td>(sex in {female}) &amp; (sibsp &lt; 2.5 or sibsp is NA)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td></td>\n",
       "      <td>M1T37N10</td>\n",
       "      <td>-2.319945e-01</td>\n",
       "      <td>(sex in {male} or sex is NA) &amp; (age &gt;= 14.977890968322754 or age i...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td></td>\n",
       "      <td>M4T3N45</td>\n",
       "      <td>-2.797404e-02</td>\n",
       "      <td>(sex in {male} or sex is NA) &amp; (pclass in {2, 3} or pclass is NA) ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td></td>\n",
       "      <td>M1T1N7</td>\n",
       "      <td>2.887806e-14</td>\n",
       "      <td>(pclass in {1, 2}) &amp; (sex in {female})</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     variable   coefficient  \\\n",
       "0    M2T21N13  1.298409e+00   \n",
       "1    M2T23N21 -8.453746e-01   \n",
       "2      M1T0N7  3.809983e-01   \n",
       "3    M1T28N10 -3.448192e-01   \n",
       "4     M1T23N7  3.310857e-01   \n",
       "5    M1T37N10 -2.319945e-01   \n",
       "6     M4T3N45 -2.797404e-02   \n",
       "7      M1T1N7  2.887806e-14   \n",
       "\n",
       "                                                                    rule  \n",
       "0  (sex in {female}) & (sibsp < 3.5 or sibsp is NA) & (pclass in {1, ...  \n",
       "1  (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) ...  \n",
       "2                                 (pclass in {1, 2}) & (sex in {female})  \n",
       "3  (sex in {male} or sex is NA) & (age >= 13.496771812438965 or age i...  \n",
       "4                       (sex in {female}) & (sibsp < 2.5 or sibsp is NA)  \n",
       "5  (sex in {male} or sex is NA) & (age >= 14.977890968322754 or age i...  \n",
       "6  (sex in {male} or sex is NA) & (pclass in {2, 3} or pclass is NA) ...  \n",
       "7                                 (pclass in {1, 2}) & (sex in {female})  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import display\n",
    "display(rfit.rule_importance())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are several rules that can be recapped as:\n",
    "\n",
    "### Higgest Likelihood of Survival:\n",
    "1. women in class 1 or 2 with 3 siblings/spouses aboard or less\n",
    "2. women in class 1 or 2\n",
    "3. women with 2 siblings/spouses aboard or less\n",
    "\n",
    "### Lowest Likelihood of Survival:\n",
    "1. male in class 2 or 3 of age >= 9.4\n",
    "2. male of age >= 13.4\n",
    "3. male of age >= 14.8\n",
    "4. male in class 2 or 3 with no parents/children aboard of age between 14 to 61\n",
    "\n",
    "Note: The rules are additive. That means that if a passenger is described by multiple rules, their probability is added together from those rules."
   ]
  }
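  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the note on additivity concrete, here is a minimal sketch in plain Python. It uses only the two rules whose text is fully displayed in the rule importance table (`M1T0N7` and `M1T23N7`) with their coefficients as printed there. The lambda predicates are hand-written translations of the rule text, the passenger is hypothetical, and `None` stands in for NA. The model intercept is not shown in the table, so the sketch reports only the change in log-odds contributed by these two rules, not a probability.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# (coefficient, predicate) pairs; coefficients are copied from the rule importance table.\n",
    "rules = {\n",
    "    'M1T0N7': (0.3809983, lambda p: p['pclass'] in (1, 2) and p['sex'] == 'female'),\n",
    "    'M1T23N7': (0.3310857, lambda p: p['sex'] == 'female'\n",
    "                and (p['sibsp'] is None or p['sibsp'] < 2.5)),\n",
    "}\n",
    "\n",
    "def log_odds_contribution(passenger):\n",
    "    # Sum the coefficients of all rules that the passenger satisfies.\n",
    "    return sum(coef for coef, applies in rules.values() if applies(passenger))\n",
    "\n",
    "# Hypothetical passenger: a woman in first class travelling with one sibling/spouse.\n",
    "passenger = {'sex': 'female', 'pclass': 1, 'sibsp': 1}\n",
    "delta = log_odds_contribution(passenger)\n",
    "\n",
    "print(f'change in log-odds: {delta:.4f}')\n",
    "print(f'odds multiplier:    {math.exp(delta):.2f}')\n",
    "```\n",
    "\n",
    "Both rules fire for this passenger, so the log-odds rise by about 0.71, which multiplies the odds of survival by roughly 2.04 compared with a passenger who matches neither rule, all else being equal."
   ]
  }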
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "h2o3pyenv",
   "language": "python",
   "name": "h2o3pyenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
